
graph: backend: dnnl: support select with binary primitive #2349

Open — wants to merge 2 commits into main from jiexin-zheng/main/select_op
Conversation

@Jiexin-Zheng (Contributor) commented Jan 7, 2025

Description

  1. A cond input is defined for the dnnl binary op.
  2. The binary primitive does not support broadcasting the cond input, so the binary select primitive is used for the non-broadcast case only. The lowering logic is: always lower Select to the binary primitive, then let the pass decompose_select_to_multiple_binary_ops decide which implementation path to use, decomposing the op into multiple binary ops if necessary.
  3. It is unsupported on GPU because the binary primitive does not support it there either.
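For illustration, the fallback decomposition can be sketched in plain C++ (a minimal sketch of the math only, not the pass code; `select_via_binary`, its signature, and the boolean cond representation are assumptions): select(cond, src0, src1) is computed with binary ops only, as cond * src0 + (1 - cond) * src1, i.e. two multiplies and one add.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch: lower select(cond, src0, src1) to binary ops only,
// as the decompose path does when the binary select primitive cannot be
// used. Element-wise: out = cond * src0 + (1 - cond) * src1.
std::vector<float> select_via_binary(const std::vector<uint8_t> &cond,
                                     const std::vector<float> &src0,
                                     const std::vector<float> &src1) {
    std::vector<float> out(src0.size());
    for (size_t i = 0; i < out.size(); ++i) {
        float c = static_cast<float>(cond[i]); // 0.0f or 1.0f
        out[i] = c * src0[i] + (1.0f - c) * src1[i]; // mul, mul, add
    }
    return out;
}
```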

Performance

Relative performance on Intel(R) Xeon(R) Platinum 8490H:

| case | speedup |
| --- | --- |
| `./tests/benchdnn/benchdnn --graph --mode=P --reset --in-shapes=1:1x12x128x128+2:1x12x128x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json` | 98.24% |
| `./tests/benchdnn/benchdnn --graph --mode=P --reset --dt=bf16 --in-shapes=1:1x12x128x128+2:1x12x128x128 --case=complex_fusion/mha/MHA-distill_bert-inf-fp32-bs1.json` | 170.12% |
| `./tests/benchdnn/benchdnn --graph --mode=P --reset --in-shapes=1:1x12x128x128+2:1x12x128x128 --case=complex_fusion/mha/MHA-distill_bert-inf-int8-bs1.json` | 140.41% |

@Jiexin-Zheng added the component:graph-api label (Codeowner: @oneapi-src/onednn-graph) Jan 7, 2025
@Jiexin-Zheng self-assigned this Jan 7, 2025
@Jiexin-Zheng requested a review from a team as a code owner January 7, 2025 14:48
@github-actions bot added the component:tests label (Codeowner: @oneapi-src/onednn-arch) Jan 7, 2025
@Jiexin-Zheng force-pushed the jiexin-zheng/main/select_op branch from 6be21c9 to 4ea4e67 on January 7, 2025 15:37
@Jiexin-Zheng (Contributor, Author) commented:
make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

Resolved review threads (outdated):
src/graph/backend/dnnl/passes/lower.cpp
src/graph/backend/dnnl/kernels/sdp_decomp.cpp
src/graph/backend/dnnl/passes/lower.cpp
src/graph/backend/dnnl/passes/transform.cpp
tests/gtests/graph/unit/backend/dnnl/test_sdp_decomp.cpp
tests/gtests/graph/unit/backend/dnnl/test_select.cpp
tests/gtests/graph/unit/utils.hpp
tests/gtests/graph/unit/utils.hpp
@Jiexin-Zheng force-pushed the jiexin-zheng/main/select_op branch from 4ea4e67 to b94bafa on January 8, 2025 12:52
@Jiexin-Zheng requested a review from a team as a code owner January 8, 2025 12:52
@Jiexin-Zheng changed the title from "graph: backend: dnnl: support select pattern and op with binary primitive" to "graph: backend: dnnl: support select with binary primitive" Jan 8, 2025
@Jiexin-Zheng (Contributor, Author) commented:
make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

@Jiexin-Zheng force-pushed the jiexin-zheng/main/select_op branch from b94bafa to 458e748 on January 8, 2025 14:28
@Jiexin-Zheng (Contributor, Author) commented:
make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

Resolved review thread (outdated): src/graph/backend/dnnl/passes/transform.hpp
@TaoLv (Contributor) left a comment:
Do you have any performance data to share?

Resolved review thread (outdated): src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
@@ -2266,7 +2266,8 @@ status_t binary_canonicalization(std::shared_ptr<subgraph_t> &sg) {
     int32_t src1_ndims = src1_lt.ndims;
     int32_t target_ndims = std::max(src0_ndims, src1_ndims);
     std::vector<int32_t> in_ndims {src0_ndims, src1_ndims};
-    for (size_t i = 0; i < cur_op->num_inputs(); ++i) {
+    std::vector<size_t> input_indices = {0, 1};
+    for (auto i : input_indices) {
A contributor commented:

Is this correct? Previously, num_inputs() could be 2–32 per the schema definition. Now the code only handles the first two?

@Jiexin-Zheng (Contributor, Author) replied:

This pass is applied before the post-op fusion pass, so the input number is always 2 at this point. For this PR, although binary select has three inputs, the cond dims are guaranteed to match those of src0 by the pass decompose_select_to_binary_ops, so we only need to unsqueeze src0 and src1.
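The unsqueeze step discussed here can be sketched as prepending leading 1s until an input reaches the target rank (a minimal illustration under that assumption; `unsqueeze_to` is a hypothetical helper, not the actual pass code):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: canonicalize a lower-rank input for the binary op
// by prepending 1s ("unsqueeze") until it matches target_ndims, making
// the inputs broadcast-compatible after lowering.
std::vector<int64_t> unsqueeze_to(const std::vector<int64_t> &dims,
                                  int32_t target_ndims) {
    std::vector<int64_t> out;
    int32_t missing = target_ndims - static_cast<int32_t>(dims.size());
    for (int32_t i = 0; i < missing; ++i) out.push_back(1); // leading 1s
    out.insert(out.end(), dims.begin(), dims.end());
    return out;
}
```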

A contributor replied:

If the cond dims are guaranteed to be the same as src0's, then it should fall into the condition if (in_ndims[i] == target_ndims) { continue; }, so no unsqueeze is inserted. If that is the case, is there a need to limit input_indices?

@Jiexin-Zheng (Contributor, Author) replied on Jan 10, 2025:

in_ndims only has two elements, so accessing a third element would not be legal.

A contributor replied:

OK, then it seems the original code was designed for 2 elements.

A contributor replied:

> This pass is applied before the post-op fusion pass, so the input number is always 2 at this point. For this PR, although binary select has three inputs, the cond dims are guaranteed to match those of src0 by the pass decompose_select_to_binary_ops, so we only need to unsqueeze src0 and src1.

This explanation looks suspicious, as the code relies on quite a few assumptions to work properly. You should at least add a comment documenting them.
BTW: I feel for (size_t i : {0, 1}) { ... } should work without defining input_indices.

@Jiexin-Zheng (Contributor, Author) replied:

Fixed: I kept the original for loop and made it skip the unsqueeze step when iterating over the third input.
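The agreed fix can be sketched as follows (illustrative only; the real loop body inserts unsqueeze ops into the subgraph, and `inputs_to_unsqueeze` and its parameters are assumed names): iterate over all inputs as before, but only apply the unsqueeze step to src0/src1.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the fix: keep the original loop over all inputs,
// but only src0/src1 (the entries of in_ndims) are unsqueeze candidates;
// the third (cond) input is skipped since its dims already match src0.
std::vector<size_t> inputs_to_unsqueeze(const std::vector<int32_t> &in_ndims,
                                        size_t num_inputs) {
    const int32_t target_ndims
            = *std::max_element(in_ndims.begin(), in_ndims.end());
    std::vector<size_t> result;
    for (size_t i = 0; i < num_inputs; ++i) {
        if (i >= in_ndims.size()) continue; // cond input: skip unsqueeze
        if (in_ndims[i] == target_ndims) continue; // already at target rank
        result.push_back(i); // this input needs an unsqueeze op
    }
    return result;
}
```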

Review threads:
src/graph/backend/dnnl/passes/transform.cpp (outdated, resolved)
src/graph/backend/dnnl/passes/utils.cpp (outdated, resolved)
src/graph/interface/shape_infer.cpp (resolved)
@Jiexin-Zheng (Contributor, Author) replied:

> Do you have any performance data to share?

Sure, I have attached it to the PR description.

@Jiexin-Zheng force-pushed the jiexin-zheng/main/select_op branch from 458e748 to f8262e0 on January 9, 2025 10:04
@Jiexin-Zheng (Contributor, Author) commented:
make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

@Jiexin-Zheng force-pushed the jiexin-zheng/main/select_op branch from f8262e0 to 325bca9 on January 9, 2025 15:58
@Jiexin-Zheng (Contributor, Author) commented:
make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

@Jiexin-Zheng force-pushed the jiexin-zheng/main/select_op branch from 325bca9 to 6694b8c on January 10, 2025 03:26
@Jiexin-Zheng (Contributor, Author) commented:
make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

@TaoLv (Contributor) left a comment:

Please separate the benchdnn inputs changes into a standalone commit.

Review threads:
src/graph/backend/dnnl/passes/utils.cpp (outdated, resolved)
src/graph/backend/dnnl/passes/utils.cpp (resolved)
@Jiexin-Zheng force-pushed the jiexin-zheng/main/select_op branch from 6694b8c to 66e2b1f on January 10, 2025 06:56
Labels
component:graph-api (Codeowner: @oneapi-src/onednn-graph), component:tests (Codeowner: @oneapi-src/onednn-arch)
Projects
None yet
6 participants